import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class Solution863 {

    /**
     * Binary tree node: an integer value plus optional left and right children.
     */
    class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;
        TreeNode() {}
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    /**
     * Records, for every node in the subtree rooted at {@code tmpNode}, its parent into
     * {@code map} at index {@code node.val}.
     *
     * <p>Legacy value-indexed helper kept for compatibility. It requires every node value to be
     * a valid index into {@code map} and values to be unique. {@link #distanceK} no longer uses it.
     *
     * @param map      parent table indexed by node value (written to)
     * @param lastNode parent of {@code tmpNode}, or {@code null} for the root
     * @param tmpNode  current subtree root; {@code null} is a no-op
     */
    void travel(TreeNode[] map, TreeNode lastNode, TreeNode tmpNode){
        if(tmpNode == null){
            return;
        }
        map[tmpNode.val] = lastNode;
        travel(map, tmpNode, tmpNode.left);
        travel(map, tmpNode, tmpNode.right);
    }

    /**
     * Walks outward from {@code tmpNode} through left, right and parent links, appending the
     * values of nodes exactly {@code k} edges away to {@code res}.
     *
     * <p>Legacy value-indexed helper kept for compatibility. It has the same value-range and
     * uniqueness requirements as {@link #travel}. {@link #distanceK} no longer uses it.
     */
    void threeDirectTravel(List<Integer> res, TreeNode[] map, boolean[] traveled, TreeNode tmpNode, int distance, int k){
        if(tmpNode == null || traveled[tmpNode.val] || distance > k){
            return;
        }
        traveled[tmpNode.val] = true;
        if(distance == k){
            res.add(tmpNode.val);
            return;
        }
        threeDirectTravel(res, map, traveled, tmpNode.left, distance+1, k);
        threeDirectTravel(res, map, traveled, tmpNode.right, distance+1, k);
        threeDirectTravel(res, map, traveled, map[tmpNode.val], distance+1, k);
    }

    /**
     * Records the parent of every node in the subtree rooted at {@code node}, keyed by node
     * identity, so arbitrary values (negative, large, or duplicated) are supported.
     *
     * @param parents output map from node to its parent (the root maps to {@code null})
     * @param parent  parent of {@code node}, or {@code null} for the root
     * @param node    current subtree root; {@code null} is a no-op
     */
    private void recordParents(Map<TreeNode, TreeNode> parents, TreeNode parent, TreeNode node) {
        if (node == null) {
            return;
        }
        parents.put(node, parent);
        recordParents(parents, node, node.left);
        recordParents(parents, node, node.right);
    }

    /**
     * Depth-first walk from {@code node} treating the tree as an undirected graph (left, right,
     * then parent), appending the value of each node first reached at exactly {@code k} edges.
     *
     * @param res      output list of values at distance {@code k}
     * @param parents  node-to-parent map built by {@link #recordParents}
     * @param visited  identity set of nodes already expanded, so no edge is walked back
     * @param node     current node; {@code null} is a no-op
     * @param distance edges travelled from the target to {@code node}
     * @param k        requested distance
     */
    private void collectAtDistance(List<Integer> res, Map<TreeNode, TreeNode> parents,
                                   Set<TreeNode> visited, TreeNode node, int distance, int k) {
        if (node == null || distance > k || visited.contains(node)) {
            return;
        }
        visited.add(node);
        if (distance == k) {
            res.add(node.val);
            // Anything beyond this node is farther than k, so stop here.
            return;
        }
        collectAtDistance(res, parents, visited, node.left, distance + 1, k);
        collectAtDistance(res, parents, visited, node.right, distance + 1, k);
        collectAtDistance(res, parents, visited, parents.get(node), distance + 1, k);
    }

    /**
     * Returns the values of all nodes exactly {@code k} edges away from {@code target}.
     *
     * <p>Parent links and visited state are keyed by node identity rather than by value, so the
     * result is correct for any node values, including negative, large, or repeated ones.
     * Results are listed in depth-first order (left, right, then parent).
     *
     * @param root   root of the tree; used to discover parent links
     * @param target start node; {@code null} yields an empty list
     * @param k      distance in edges; a negative value yields an empty list
     * @return a new mutable list of matching node values
     */
    public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        List<Integer> res = new ArrayList<>();
        Map<TreeNode, TreeNode> parents = new IdentityHashMap<>();
        recordParents(parents, null, root);
        Set<TreeNode> visited = Collections.newSetFromMap(new IdentityHashMap<TreeNode, Boolean>());
        collectAtDistance(res, parents, visited, target, 0, k);
        return res;
    }
}
